#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Copyright (c) 2011 - 2013 Stefano Mazzucco <stefano -at- curso.re>
# All rights reserved.
#
# This file is part of Crystal Ball Plus.
#
# Crystal Ball Plus is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Crystal Ball Plus is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Crystal Ball Plus.  If not, see <http://www.gnu.org/licenses/>.

"""Mathematical functions for crystallography. Linear algebra, rather
than explicit functions, is used.

"""

import numpy as np
from utilities import remove_duplicates

def _gcd(a, b):
    """Return the greatest common divisor of a and b

    """
    # borrowed from numpy:
    # https://github.com/numpy/numpy/blob/f9c7bde68655de06546f5c8dba30314888e409cc/numpy/core/_internal.py
    while b:
        a, b = b, a%b
    return a

def gcd(*arg):
    """Return the greatest common divisor of two or more arguments.

    The pairwise helper _gcd is folded over the arguments from left to
    right. At least two arguments are required.

    Example:
       >>> gcd(126, 42, 945)
       21

    """
    result = _gcd(arg[0], arg[1])
    for value in arg[2:]:
        result = _gcd(result, value)
    return result

def _lcm(a, b):
    """Return the least common multiple of a and b.

    Floor division is used so that integer inputs give an integer result on
    both Python 2 and Python 3 (a plain '/' yields a float on Python 3).
    Since _gcd(a, b) divides a * b exactly, nothing is lost by flooring.

    """
    return a * b // _gcd(a, b)

def lcm(*arg):
    """Return the least common multiple of two or more arguments.

    The pairwise helper _lcm is folded over the arguments from left to
    right. At least two arguments are required.

    Example:
       >>> lcm(18, 9, 27)
       54

    """
    result = _lcm(arg[0], arg[1])
    for value in arg[2:]:
        result = _lcm(result, value)
    return result

def direct_metric_tensor(uc, deg=True, zero=None):
    """Return the direct metric tensor for the lattice defined by
    the given unit cell 'uc'.

    The tensor is::

        [[a**2,           a*b*cos(gamma), a*c*cos(beta) ],
         [a*b*cos(gamma), b**2,           b*c*cos(alpha)],
         [a*c*cos(beta),  b*c*cos(alpha), c**2          ]]

    *Parameters*

    uc : array of 6 elements
         Unit cell: (a, b, c, alpha, beta, gamma)

    deg : boolean (optional)
          Whether the angles in the unit cell are degrees (True, default)
          or radians (False)

    zero : float (optional)
           Tolerance: every tensor element whose ABSOLUTE value is not
           greater than 'zero' is set to 0.0 (e.g. to remove the floating
           point noise of cos(90 deg) ~ 6e-17). Negative elements larger in
           magnitude than 'zero' (obtuse angles) are kept.
           If zero=None (default), no approximation will be applied.

    *Returns*

    dmt : 3x3 numpy array
          direct metric tensor (symmetric matrix)

    """
    a = uc[0]
    b = uc[1]
    c = uc[2]
    alpha = uc[3]
    beta = uc[4]
    gamma = uc[5]
    if deg:
        alpha = np.deg2rad(alpha)
        beta = np.deg2rad(beta)
        gamma = np.deg2rad(gamma)

    # Only the upper triangle is filled here; it is mirrored below.
    dmt = np.array([
        [ a ** 2, a * b * np.cos(gamma), a * c * np.cos(beta)],
        [ 0, b ** 2, b * c * np.cos(alpha)],
        [ 0, 0, c ** 2]
        ])
    if zero is not None:
        # Compare magnitudes: a plain 'dmt > zero' would also wipe out the
        # legitimate negative off-diagonal terms of obtuse cells (e.g. the
        # -a*b/2 term of a hexagonal cell with gamma = 120 deg).
        msk = (np.abs(dmt) > zero)
        dmt = np.where(msk, dmt, 0.0)
    idxl = np.tril_indices(3, -1)     # Indices of the lower half.
    idxu = np.triu_indices(3, 1)      # Indices of the upper half.
    dmt[idxl] = dmt[idxu]             # The matrix is symmetric.
    assert np.allclose(dmt, dmt.T), 'metric tensor is not symmetric:\n %s' % dmt
    return dmt

def reciprocal_metric_tensor(uc, deg=True, zero=None):
    """Return the reciprocal metric tensor of the lattice whose unit cell
    is 'uc', obtained as the matrix inverse of the direct metric tensor.

    The ROWS of the returned matrix give the reciprocal basis in terms of
    the direct basis a, b, c:
    a(*) = rmt00 a + rmt01 b + rmt02 c
    b(*) = rmt10 a + rmt11 b + rmt12 c
    c(*) = rmt20 a + rmt21 b + rmt22 c

    *Parameters*

    uc : array of 6 elements
         Unit cell: (a, b, c, alpha, beta, gamma)

    deg : boolean (optional)
          Whether the angles in the unit cell are degrees (True, default)
          or radians (False)

    zero : float (optional)
           Threshold forwarded to direct_metric_tensor before inversion.
           If zero=None (default), no approximation will be applied

    *Returns*

    rmt : 3x3 numpy array
          reciprocal metric tensor (symmetric matrix)

    """
    direct = direct_metric_tensor(uc, deg, zero)
    inverse = np.linalg.inv(direct)
    # Sanity check: g . g* must be the identity.
    assert(np.allclose(direct.dot(inverse), np.eye(3))), 'Inversion failed!'
    return inverse

def unitcell(mt, deg=True):
    """Return a unit cell built from the given metric tensor.

    Edge lengths are the norms of the basis vectors and angles are the
    angles between them. Because the columns of 'mt' express the basis in
    terms of the INVERSE basis, both are measured with inv(mt).

    *Parameters*

    mt : 3x3 numpy array
         metric tensor (must be a symmetric matrix)

    deg : boolean (optional)
          Whether the returned angles are in degrees (True, default)
          or radians (False)

    *Returns*

    uc : array of 6 elements
         Unit cell: (a, b, c, alpha, beta, gamma)
         if deg=True (default) the alpha, beta, and gamma will be in degrees,
         if deg=False they will be in radians

    """
    assert(np.allclose(mt, mt.T)), 'input metric tensor is not symmetric:\n %s' % mt
    # the ROWS (or the COLUMNS, since the m.t. is symmetric)
    # of the metric tensor contain the components of
    # the basis expressed in terms of the INVERSE basis
    v_a = mt[:, 0]
    v_b = mt[:, 1]
    v_c = mt[:, 2]
    # calculate norms and angles in the INVERSE space
    rmt = np.linalg.inv(mt)
    # Report the matrix that actually failed the check (rmt, not mt).
    assert(np.allclose(rmt, rmt.T)), 'inverse metric tensor is not symmetric:\n %s' % rmt
    I = np.eye(3)
    assert(np.allclose(np.dot(mt, rmt), I)), 'Inversion failed!'
    a = ncnorm(v_a, rmt)
    b = ncnorm(v_b, rmt)
    c = ncnorm(v_c, rmt)
    alpha = theta(v_b, v_c, rmt, deg)
    beta = theta(v_c, v_a, rmt, deg)
    gamma = theta(v_a, v_b, rmt, deg)
    return np.array((a, b, c, alpha, beta, gamma))

def unitcell_volume(uc, deg=True):
    """Return the volume of the unit cell 'uc'.
    Rather than using the general formula for a triclinic unit cell:

    vol = a * b * c * sqrt(1 - cos(alpha) ** 2 - cos(beta) ** 2
                           - cos(gamma) ** 2
                           + 2 * cos(alpha) * cos(beta) * cos(gamma))

    it calculates the volume as the square root of the determinant of
    the direct metric tensor of the unit cell.

    *Parameters*

    uc : array of 6 elements
         Unit cell: (a, b, c, alpha, beta, gamma)

    deg : boolean (optional)
          Whether the angles in the unit cell are degrees (True, default)
          or radians (False)

    *Returns*

    ucvol : float
            volume of the unit cell

    """
    dmt = direct_metric_tensor(uc, deg)
    # det(g) = vol ** 2; it must be strictly positive for a real cell.
    volsq = np.linalg.det(dmt)
    assert(volsq > 0), 'Squared volume is negative: %s' % volsq
    return np.sqrt(volsq)

def unitcell_volume_explicit(uc, deg=True):
    """Return the volume of the unit cell 'uc'.
    Use the explicit formula for a triclinic unit cell:

    vol = a * b * c * sqrt(1 - cos(alpha) ** 2 - cos(beta) ** 2
                           - cos(gamma) ** 2
                           + 2 * cos(alpha) * cos(beta) * cos(gamma))

    *Parameters*

    uc : array of 6 elements
         Unit cell: (a, b, c, alpha, beta, gamma)

    deg : boolean (optional)
          Whether the angles in the unit cell are degrees (True, default)
          or radians (False)

    *Returns*

    ucvol : float
            volume of the unit cell

    """
    a = uc[0]
    b = uc[1]
    c = uc[2]
    alpha = uc[3]
    beta = uc[4]
    gamma = uc[5]
    if deg:
        alpha = np.deg2rad(alpha)
        beta = np.deg2rad(beta)
        gamma = np.deg2rad(gamma)
    cosa, cosb, cosg = np.cos(alpha), np.cos(beta), np.cos(gamma)
    # Kept in its own variable so that the edge length 'a' is not clobbered.
    arg = 1 - cosa ** 2 - cosb ** 2 - cosg ** 2 + 2 * cosa * cosb * cosg
    assert(arg > 0), 'Square root has negative argument: %s' % arg
    vol = a * b * c * np.sqrt(arg)
    return vol

def nccross(p, q, mt):
    """Return the non-Cartesian cross product of the vectors p and q
    computed with the metric tensor 'mt'.

    The result is the ordinary cross product of the components scaled by
    the volume of the cell described by 'mt'; its components are expressed
    in the INVERSE basis.

    *Parameters*

    p : numpy array of 3 elements

    q : numpy array of 3 elements

    mt : 3x3 numpy array
         metric tensor. It can be e.g. the direct or reciprocal metric tensor

    *Returns*

    c : numpy array of 3 elements
        cross product between vectors p and q; the components are expressed
        in terms of the INVERSE basis

    """
    cell_volume = unitcell_volume(unitcell(mt))
    return cell_volume * np.cross(p, q)

def ncdot(p, q, mt):
    """Return the non-Cartesian dot product between two vectors p and q
    using the given metric tensor 'mt', i.e. p . mt . q.

    Note that this is valid in every rectilinear and curvilinear coordinate
    frame and, thus, in every crystal system.

    *Parameters*

    p : numpy array of 3 elements

    q : numpy array of 3 elements

    mt : 3x3 numpy array
         metric tensor. It can be e.g. the direct or reciprocal metric tensor

    *Returns*

    d : float
        dot product of vectors p and q (not a distance; see ncdist)

    """
    # d = p(3-row) * mt(3x3) * q(3-column). numpy's dot treats a 1-D array
    # as a column on the right and as a row on the left, so no transposes
    # are needed.
    return np.dot(p, np.dot(mt, q))

def ncnorm(p, mt):
    """Return the non-Cartesian norm of the vector p
    using the given metric tensor mt.

    The non-Cartesian norm is simply the square root of ncdot(p, p, mt).
    The null vector has norm 0.

    *Parameters*

    p : numpy array of 3 elements

    mt : 3x3 numpy array
         metric tensor. It can be e.g. the direct or reciprocal metric tensor

    *Returns*

    norm : float
           non-Cartesian norm of vector p

    """
    nrm = np.dot(p, np.dot(mt, p))
    # For a valid (positive definite) metric tensor the squared norm is
    # positive, or exactly zero for the null vector (e.g. ncdist between
    # coincident points). Only a negative value signals a bad tensor.
    assert(nrm >= 0), 'The square of the norm (%g) is negative!' % nrm
    return np.sqrt(nrm)

def ncdist(a, b, mt):
    """Return the distance between the points a and b measured with the
    metric tensor mt, i.e. the ncnorm of the vector a - b.

    *Parameters*

    a : numpy array of 3 elements

    b : numpy array of 3 elements

    mt : 3x3 numpy array
         metric tensor. It can be e.g. the direct or reciprocal metric tensor

    *Returns*

    d : float
        distance between points a and b

    """
    separation = np.asarray(a) - np.asarray(b)
    return ncnorm(separation, mt)

def theta(p, q, mt, deg=True):
    """Return the angle between two vectors p and q
    using the given metric tensor mt.

    *Parameters*

    p : numpy array of 3 elements

    q : numpy array of 3 elements

    mt : 3x3 numpy array
         metric tensor. It can be e.g. the direct or reciprocal metric tensor

    deg : boolean
          whether the returned angle is in degrees (True, default)
          or radians (False)

    *Returns*

    theta : float
            angle between vectors p and q
            if deg=True (default) the units will be degrees,
            if deg=False the units will be radians

    """
    prd = ncdot(p, q, mt)
    normp = ncnorm(p, mt)
    normq = ncnorm(q, mt)
    costheta = prd / (normp * normq)
    # Fix floating point noise around +-1, where arccos is ill-conditioned:
    #  - values overshooting [-1, 1] (arccos undefined) but close to +-1 are
    #    clamped onto the boundary;
    #  - values inside the interval are snapped onto +-1 only when within
    #    'tight' of it. This flattens only angles below ~8e-5 degrees, so
    #    genuinely small angles (e.g. 0.01 degrees) are preserved.
    tight = 1e-12
    overshoot = abs(costheta) > 1.0 and np.isclose(abs(costheta), 1.0)
    if overshoot or abs(abs(costheta) - 1.0) <= tight:
        costheta = 1.0 if costheta > 0 else -1.0
    assert(costheta <= 1.0), 'Wrong cosine value %e' % costheta
    assert(costheta >= -1.0), 'Wrong cosine value %e' % costheta
    angle = np.arccos(costheta)
    if deg:
        angle = np.rad2deg(angle)
    return angle

def d_spacing(hkl, uc, deg=True, zero=None):
    """Return the interplanar spacing d_hkl of the lattice planes with
    Miller indices 'hkl' in the crystal whose unit cell is 'uc'.

    d_hkl is the inverse of the length of the reciprocal lattice vector
    with components (h, k, L), that length being measured with the
    reciprocal metric tensor.

    *Parameters*

    hkl : array of 3 elements
          the Miller indices of a lattice plane

    uc : array of 6 elements
         unit cell: (a, b, c, alpha, beta, gamma)

    deg : boolean (optional)
          Whether the angles in the unit cell are degrees (True, default)
          or radians (False)

    zero : float (optional)
           Threshold forwarded to the metric tensor computation.
           If zero=None (default), no approximation will be applied

    *Returns*

    d : float
        interplanar spacing

    """
    g_star = reciprocal_metric_tensor(uc, deg, zero)
    # |g*_hkl| = 1 / d_hkl
    return 1. / ncnorm(hkl, g_star)

def zone_axis(hkl1, hkl2, deg=True):
    """Return the zone axis [u, v, w] shared by the crystallographic planes
    with Miller indices hkl1 and hkl2.

    The zone axis is the cross product hkl1 x hkl2 divided by the greatest
    common divisor of its components.

    *Parameters*

    hkl1 : array of 3 elements
           Miller indices (h1, k1, L1)

    hkl2 : array of 3 elements
           Miller indices (h2, k2, L2)

    deg : boolean (optional)
          Not used; kept for backward compatibility of the signature.

    *Returns*

    uvw : array of 3 elements
          zone axis (u, v, w). Integer input gives integer output. The
          overall sign follows the sign returned by gcd, so [u, v, w] or
          [-u, -v, -w] may be returned.

    *Raises*

    ValueError : if hkl1 x hkl2 is the null vector (parallel planes), in
                 which case no unique zone axis exists.

    """
    hkl1_a, hkl2_a = np.asarray(hkl1), np.asarray(hkl2)
    uvw = np.cross(hkl1_a, hkl2_a)
    k = gcd(uvw[0], uvw[1], uvw[2])
    if k == 0:
        raise ValueError('Planes %s and %s are parallel: the zone axis is '
                         'undefined' % (hkl1_a, hkl2_a))
    if np.issubdtype(uvw.dtype, np.integer):
        # k divides every component exactly; floor division keeps integer
        # indices integer on Python 3 as well.
        uvw = uvw // k
    else:
        uvw = uvw / k
    assert(np.dot(hkl1_a, uvw) == 0), 'Zone axis %s does not belong to %s' % (uvw, hkl1_a)
    assert(np.dot(hkl2_a, uvw) == 0), 'Zone axis %s does not belong to %s' % (uvw, hkl2_a)
    return uvw

def plane2mill(a):
    """Return the Miller indices corresponding to the plane determined
    by vector 'a'.

    Due to floating point issues(*), 'a' must be input as list of STRINGS,
    these will then be converted into instances of the built-in class
    Fraction.

    (*) http://docs.python.org/tutorial/floatingpoint.html#tut-fp-issues

    The components of 'a' are the intercepts in units of the basis vectors'
    lengths of the unit cell, e.g. a = ('1/2', '3', '1'). The string 'inf'
    (infinity; '+inf', '-inf' and any letter case are accepted too) must be
    provided when the plane is parallel to a given direction.
    The plane cannot intercept the origin, i.e. no component can be zero,
    and it cannot be parallel to all three axes.

    Note that planes parallel with each other have the same Miller indices.

    *Parameters*

    a : tuple of 3 strings

    *Returns*

    m : array of 3 elements
        Miller indices (hkl)

    *Raises*

    ValueError : if 'a' does not have 3 components, if any intercept is
                 zero, or if all three intercepts are infinite.

    """
    if len(a) != 3:
        err = 'Input parameter must have 3 components, it has %s' % len(a)
        raise ValueError(err)
    from fractions import Fraction
    # Reciprocals of the intercepts; an infinite intercept gives 0.
    m = []
    for i in np.asarray(a):
        if str(i).strip().lower() in ('inf', '+inf', '-inf'):
            m.append(Fraction('0'))
            continue
        intercept = Fraction(i)
        if intercept == 0:
            # Checked numerically so '0.0', '0/5', '-0'... are caught too.
            # The 1-tuple keeps a 3-tuple 'a' from being consumed as three
            # separate format arguments.
            raise ValueError('The plane intercepts the origin: %s' % (a,))
        m.append(1 / intercept)
    if all(x == 0 for x in m):
        raise ValueError('The plane is parallel to all three axes: %s' % (a,))
    # Scale by the LCM of the denominators to get integers; int() guards
    # against lcm returning a float.
    k = int(lcm(m[0].denominator, m[1].denominator, m[2].denominator))
    return np.asarray([(x * k).numerator for x in m])

def family(v, sg, space='d', uc=None, deg=True, rnd=False):
    """Given the vector 'v', return the set of equivalent directions (family)
    in the crystal system defined by the space group 'sg'.

    Note that if you want the family of a Miller plane you should set the
    keyword 'space' to 'r' (reciprocal) and specify the unit cell.

    The components of the vector MUST be integers, and integer (within the
    floating point approximation) components should be returned as well.

    One can set the keyword 'rnd' to 'True' to round the result. That might
    be useful when working in the reciprocal space.

    *Parameters*

    v : array of 3 elements

    sg : instance of SpaceGroup
         Only its 'iter_symops()' method is used; each yielded symmetry
         operation must provide the rotation matrix as attribute 'R'.

    space : string (optional)
            whether 'v' is defined in the direct space (space='d', default)
            or in the reciprocal space (space='r')

    uc : array of 6 elements (optional, required when space='r')
         Unit cell: (a, b, c, alpha, beta, gamma)

    deg : boolean (optional)
          Whether the angles in the unit cell are degrees (True, default)
          or radians (False)

    rnd : boolean (optional)
          whether to round the result using numpy.around (rnd=True)
          or not (rnd=False, default)
          NOTE: if space='r', rnd is set to 'True'

    *Returns*

    family : list of arrays

    *Raises*

    ValueError : if 'space' is neither 'd' nor 'r', or if space='r' and
                 'uc' is None
    TypeError : if 'sg' has no 'iter_symops' method

    """
    if space not in ('d', 'r'):
        raise ValueError("keyword 'space' MUST be 'd' or 'r', you set '%s'."
                         % space)
    family = []
    if space == 'r':
        rnd = True              # reciprocal space needs rounding
        if uc is None:
            raise ValueError("keyword 'uc' (unit cell) not set")
        dmt = direct_metric_tensor(uc, deg)
        rmt = reciprocal_metric_tensor(uc, deg)
        a =  np.dot(rmt, v)     # transform to direct space
    else:
        a = np.asarray(v)
    if hasattr(sg, 'iter_symops'):
        for symop in sg.iter_symops():
            e = np.dot(symop.R, a) # apply point group matrix (rotation)
            if space == 'r':
                e = np.dot(dmt, e) # transform back to reciprocal space
            if rnd:
                e = np.around(e)
            family.append(e)
    else:
        raise TypeError("invalid parameter 'sg' (space group)")
    # The return value is ignored, so remove_duplicates presumably
    # de-duplicates 'family' in place -- confirm in utilities.
    remove_duplicates(family)
    return family

def switch_space(v, mt):
    """Return a vector whose components correspond to v transformed
    into the space defined by the metric tensor mt.

    The result is v . inv(mt): v is multiplied by the INVERSE of 'mt', not
    by 'mt' itself. Since inv(rmt) = dmt, this means for example:

    - switch_space(uvw, rmt) turns direct-space components into
      reciprocal-space components;
    - switch_space(hkl, dmt) turns reciprocal-space components into
      direct-space components.

    *Parameters*

    v : numpy array of 3 elements

    mt : 3x3 numpy array
         metric tensor of the TARGET space, e.g. the direct or reciprocal
         metric tensor

    *Returns*

    r : numpy array of 3 elements
        transformed vector in the space defined by mt
    """
    inv_mt = np.linalg.inv(mt)
    I = np.eye(3)
    assert(np.allclose(np.dot(mt, inv_mt), I)), 'Inversion failed!'
    return np.dot(v, inv_mt)

def switch_uc(uc, deg=True, rnd=None):
    """Return the transformed unit cell in the inverse metric space.
    This allows one to compute the reciprocal unit cell from the (direct)
    unit cell and vice versa.

    Note that the result will likely be approximate due to floating point
    issues(*), e.g. you will have '59.999999999999943' instead of '60.0'.
    You can avoid that by setting rnd to the number of decimals that
    you want to round to. However, you will likely suffer from accuracy
    issues.

    (*) http://docs.python.org/tutorial/floatingpoint.html#tut-fp-issues

    *Parameters*

    uc : array of 6 elements
         Unit cell: (a, b, c, alpha, beta, gamma)

    deg : boolean (optional)
          Whether the angles in the unit cell are degrees (True, default)
          or radians (False). The returned angles use the same unit.

    rnd : integer (optional)
          number of decimals used to round the result; None (default)
          means no rounding

    *Returns*

    uc_t : array of 6 elements
           transformed unit cell: (a_t, b_t, c_t, alpha_t, beta_t, gamma_t)
           It has the SAME units as 'uc'.

    """
    mt = direct_metric_tensor(uc, deg)
    rmt = reciprocal_metric_tensor(uc, deg)
    I = np.eye(3)
    assert(np.allclose(np.dot(mt, rmt), I)), 'Inversion failed!'
    # Forward 'deg' so the angles come back in the same unit as the input.
    uc_t = unitcell(rmt, deg)
    if rnd is not None:
        print('Rounding result to %i decimals. You will likely suffer of accuracy issues' % rnd)
        uc_t = np.around(uc_t, decimals=rnd)
    return uc_t

def cartesian(w, uc, deg=True):
    """Return the coordinates of vector w - expressed in the non-Cartesian
    reference frame defined by the unit cell uc - in a Cartesian reference
    frame.

    The Cartesian reference frame basis vectors e1, e2, and e3 are chosen
    as follows::

        e1 = a1 / |a1|
        e2 = e3 x e1
        e3 = astar3 / |astar3|

    where a1 is the first vector basis of the direct lattice, astar3 is the
    third basis of the reciprocal lattice, 'x' denotes the cross product,
    and `|.|` denotes the norm.

    That means that e1 is parallel to a1 and e3 is parallel to astar3.

    Relationship between Cartesian and non-Cartesian reference frames::

         `a3
          `   |e3//astar3
           `  |
            ` |_________ e2
             /\
            /  \
           /    \ a2
          /
         /e1//a1

    The transformation matrix is the direct structure matrix, whose
    columns are a1, a2 and a3 in the Cartesian frame (g and g* are the
    direct and reciprocal metric tensors, 1-based indices, V the cell
    volume)::

        [ sqrt(g11),  g12 / sqrt(g11),     g13 / sqrt(g11)          ]
        [ 0,          V sqrt(g*33 / g11),  -V g*23 / sqrt(g*33 g11) ]
        [ 0,          0,                   1 / sqrt(g*33)           ]

    Hence the Cartesian length of the result equals ncnorm(w, dmt).

    *Parameters*

    w : array of 3 elements
        coordinates of vector w in non-Cartesian reference frame

    uc : array of 6 elements
         Unit cell: (a, b, c, alpha, beta, gamma) that defines the non-Cartesian
         reference frame

    deg : boolean (optional)
          Whether the angles in the unit cell are degrees (True, default)
          or radians (False)

    *Returns*

    w_C : array of 3 elements
        coordinates of vector w in a Cartesian reference frame
    """
    dmt = direct_metric_tensor(uc, deg)
    rmt = np.linalg.inv(dmt)
    I = np.eye(3)
    assert(np.allclose(np.dot(dmt, rmt), I)), 'Inversion failed!'
    v_w = np.asarray(w)
    ucvol = unitcell_volume(uc, deg)
    sqrt = np.sqrt
    L1 = dmt[0, :] / sqrt(dmt[0, 0])
    # a3 . e2 = a3 . (e3 x e1) = -V g*32 / (|astar3| |a1|): the square root
    # normalizes astar3 and a1, so it DIVIDES the product.
    L2 = np.array([0,
                   ucvol * sqrt(rmt[2, 2] / dmt[0, 0]),
                   -1 * ucvol * rmt[2, 1] / sqrt(rmt[2, 2] * dmt[0, 0])
                   ])
    L3 = np.array([0, 0, 1 / sqrt(rmt[2, 2])])
    tm = np.asarray([L1, L2, L3]) # transformation matrix
    return np.dot(tm, v_w)

def minarr(p, q):
    """Compare 'p' and 'q' element-wise and return whichever of the two
    is strictly smaller in more positions; ties go to 'q'.

    .. warning::
       This function is not reliable, use at your own risk!

    *Parameters*

    p : 1-D array
    q : 1-D array

    *Returns*

    either p or q (the original object, not a converted array)

    """
    smaller = np.asarray(p) < np.asarray(q)
    p_count = smaller.sum()
    q_count = len(smaller) - p_count
    return p if p_count > q_count else q

def mymin(a):
    """Return the minimum element of the iterable 'a'.

    Elements are compared with '<'. When several elements compare equal,
    any one of them may be returned.

    *Raises*

    IndexError : if 'a' is empty

    """
    items = list(a)
    if not items:
        # Keep the exception type callers have always seen for empty input.
        raise IndexError('mymin() arg is an empty sequence')
    # A single linear scan instead of a quadratic pairwise elimination.
    return min(items)

def minlistarr(a):
    """Return the "minimum" element of 'a', a list of arrays, obtained by
    repeated pairwise elimination with 'minarr'.

    This function is not reliable, use at your own risk!

    """
    pool = list(a)
    while len(pool) > 1:
        head = pool.pop(0)
        pool = [minarr(head, other) for other in pool]
    return pool[0]

def _most_minima_index(errs):
    """Return the index in 'errs' of the error list that survives a pairwise
    elimination in which, for each pair, the list that is strictly smaller
    in more positions wins (ties go to the later list). Order-dependent.

    """
    remaining = [np.asarray(e) for e in errs]
    while len(remaining) > 1:
        first = remaining.pop(0)
        survivors = []
        for other in remaining:
            check = (first < other)
            first_wins = check.sum()
            if first_wins > len(check) - first_wins:
                survivors.append(first)
            else:
                survivors.append(other)
        remaining = survivors
    winner = remaining[0]
    idx = [h for h, e in enumerate(errs) if np.all(winner == np.asarray(e))]
    assert(len(idx) == 1),'Problem while finding index of minimum error'
    return idx[0]


def min_err(dic, avg=True):
    """Return the elements that have most of the minimum errors.

    *Parameters*

    dic : dictionary of dictionaries
          dic[key1][key2]
          key1: input file
          key2: reference file
          value: N-tuple, where
              value[0] : structured array
                         at least one field MUST be 'err'

    avg : bool (optional)
          If True (default), choose the value with the minimum average
          error. If False, attempt to choose the value that has the most
          element-wise minimum errors (this is buggy/unreliable).

    *Returns*

    merr : dictionary of dictionaries
           merr[key1] = {key2: value} for the chosen reference of each
           input file; input files with no entries map to {}.

    """
    if not avg:
        print('\t\t### WARNING ###\t\t')
        print('This function is not reliable, use at your own risk!')
        print('\t\t###############\t\t')
    merr = {}
    for inpfile in dic:
        merr[inpfile] = {}
        # list() makes the views indexable on Python 3 too; keys() and
        # values() iterate in matching order since the dict is unchanged.
        keys = list(dic[inpfile].keys())
        values = list(dic[inpfile].values())
        if not values:
            # Nothing to choose from: leave this input file empty and go on
            # with the remaining ones.
            continue
        if avg:
            errs = [np.average(value[0]['err']) for value in values]
            idx = np.argmin(errs)
        else:
            errs = [list(value[0]['err']) for value in values]
            idx = _most_minima_index(errs)
        merr[inpfile][keys[idx]] =  values[idx]
    return merr
